import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score
import sys,os
sys.path.append(r"/home/yh579/GAFM/GAFM/models")
from bases import FirstNet,SecondNet,torch_auc,totalvaraition,Attacks
# SplitNN
import torch

#solve_isotropic_covariance
import math
import random
from collections import Counter
import numpy
# Device that train_marvell places the models and training batches on.
device='cpu'
# Convergence tolerance of the alternating optimization in solve_small_neg /
# solve_small_pos: stop once the objective improves by less than this across
# one full three-step cycle.
OBJECTIVE_EPSILON = 1e-14
# convex_min_1d stops bisecting once its bracket is at most this wide.
CONVEX_EPSILON = 1e-8
# Number of alternating-optimization starts solve_isotropic_covariance runs;
# the start with the smallest objective wins.
NUM_CANDIDATE = 1


def symKL_objective(lam10, lam20, lam11, lam21, u, v, d, g):
    """Symmetric-KL objective for a candidate set of noise parameters.

    ``lam10``/``lam11`` are added to the negative/positive coordinate
    variances ``u``/``v`` along the class-mean-difference direction and
    ``lam20``/``lam21`` along each of the other ``d - 1`` directions;
    ``g`` is the squared norm of the class-mean difference.

    Returns ``inf`` when any of the four resulting variances is exactly 0.
    """
    neg_other = lam20 + u
    pos_other = lam21 + v
    neg_along = lam10 + u
    pos_along = lam11 + v

    if any(var == 0.0 for var in (pos_other, neg_other, pos_along, neg_along)):
        return float('inf')

    other_dims = d - 1
    return (other_dims * neg_other / pos_other
            + other_dims * pos_other / neg_other
            + (neg_along + g) / pos_along
            + (pos_along + g) / neg_along)


def symKL_objective_zero_uv(lam10, lam11, g):
    """Objective of the u = v = 0 problem: ``(lam10+g)/lam11 + (lam11+g)/lam10``.

    No guard against zero denominators: a zero ``lam11`` or ``lam10``
    propagates whatever the division operator does for the argument types.
    """
    first_term = (lam10 + g) / lam11
    second_term = (lam11 + g) / lam10
    return first_term + second_term


def solve_isotropic_covariance(u, v, d, g, p, P,
                               lam10_init=None, lam20_init=None,
                               lam11_init=None, lam21_init=None):
    """Return the noise parameters minimizing the symmetric-KL objective.

    Dispatches to ``solve_zero_uv`` when ``u == v == 0``, to
    ``solve_small_neg`` when ``u <= v`` (lam21 fixed at 0) and to
    ``solve_small_pos`` otherwise (lam20 fixed at 0). ``NUM_CANDIDATE``
    alternating-optimization starts are run and the one with the smallest
    reported objective is kept.

    Args:
        u: the coordinate variance of the negative examples.
        v: the coordinate variance of the positive examples.
        d: the dimension of activation to protect.
        g: squared 2-norm of g_0 - g_1, i.e. ||g^(0) - g^(1)||_2^2.
        p: weight of the positive-class parameters in the power constraint
            (the caller passes the fraction of positives by default).
        P: the power constraint value; the solution satisfies
            p*lam11 + p*(d-1)*lam21 + (1-p)*lam10 + (1-p)*(d-1)*lam20 = P.
        lam10_init, lam20_init, lam11_init, lam21_init: optional warm starts.
            A falsy value (None or 0.0) is replaced by a random draw.

    Returns:
        tuple: (lam10, lam20, lam11, lam21, objective) where objective is the
        value reported by the chosen solver (0.5 * symKL sum minus d, or
        minus 1 in the u = v = 0 case).
    """

    if u == 0.0 and v == 0.0:
        return solve_zero_uv(g=g, p=p, P=P)

    # Shuffled so that, with NUM_CANDIDATE < 3, which variable receives the
    # initial value (and is therefore fixed first) is chosen at random.
    ordering = [0, 1, 2]
    random.shuffle(x=ordering)

    solutions = []
    if u <= v:
        for i in range(NUM_CANDIDATE):
            if i % 3 == ordering[0]:
                if lam20_init:  # if we pass an initialization
                    lam20 = lam20_init
                else:
                    lam20 = random.random() * P / (1 - p) / d
                lam10, lam11 = None, None
            elif i % 3 == ordering[1]:
                if lam11_init:
                    lam11 = lam11_init
                else:
                    lam11 = random.random() * P / p
                lam10, lam20 = None, None
            else:
                if lam10_init:
                    lam10 = lam10_init
                else:
                    lam10 = random.random() * P / (1 - p)
                lam11, lam20 = None, None

            # Exactly one of lam10/lam11/lam20 is set; it tells the solver
            # which variable to fix first.
            solutions.append(solve_small_neg(u=u, v=v, d=d, g=g, p=p, P=P, lam10=lam10, lam11=lam11, lam20=lam20))

    else:
        for i in range(NUM_CANDIDATE):
            if i % 3 == ordering[0]:
                if lam21_init:
                    lam21 = lam21_init
                else:
                    lam21 = random.random() * P / p / d
                lam10, lam11 = None, None
            elif i % 3 == ordering[1]:
                if lam11_init:
                    lam11 = lam11_init
                else:
                    lam11 = random.random() * P / p
                lam10, lam21 = None, None
            else:
                if lam10_init:
                    lam10 = lam10_init
                else:
                    lam10 = random.random() * P / (1 - p)
                lam11, lam21 = None, None

            # Exactly one of lam10/lam11/lam21 is set; it tells the solver
            # which variable to fix first.
            solutions.append(solve_small_pos(u=u, v=v, d=d, g=g, p=p, P=P, lam10=lam10, lam11=lam11, lam21=lam21))

    # Keep the candidate with the smallest reported objective (last element).
    lam10, lam20, lam11, lam21, objective = min(solutions, key=lambda x: x[-1])

    return (lam10, lam20, lam11, lam21, objective)


def solve_zero_uv(g, p, P):
    """Closed-form solution of the problem when u = v = 0.

    Minimizes ``(lam10 + g) / lam11 + (lam11 + g) / lam10`` subject to
    ``p * lam11 + (1 - p) * lam10 = P`` (see zero_uv_problem_string).

    Returns:
        tuple: (lam10, 0.0, lam11, 0.0, 0.5 * objective - 1). The spherical
        parameters are 0, and 1 (not d) is subtracted because the
        distribution is essentially one-dimensional.
    """
    ratio = math.sqrt((P + (1 - p) * g) / (P + p * g))
    stationary = max((P / p) / (ratio + (1 - p) / p), 0.0)

    if 0 <= stationary <= P / (1 - p):
        lam10 = stationary
        lam11 = max(P / p - (1 - p) * stationary / p, 0.0)
    else:
        # Stationary point outside the feasible segment: take the better of
        # its two end points (ties go to the second one).
        vertex_a = (0.0, max(P / p, 0.0))
        vertex_b = (max(P / (1 - p), 0), 0.0)
        objective_a = symKL_objective_zero_uv(lam10=vertex_a[0], lam11=vertex_a[1], g=g)
        objective_b = symKL_objective_zero_uv(lam10=vertex_b[0], lam11=vertex_b[1], g=g)
        lam10, lam11 = vertex_a if objective_a < objective_b else vertex_b

    objective = symKL_objective_zero_uv(lam10=lam10, lam11=lam11, g=g)
    return (lam10, 0.0, lam11, 0.0, 0.5 * objective - 1)


def solve_small_neg(u, v, d, g, p, P, lam10=None, lam20=None, lam11=None):
    """Alternating minimization of the symmetric-KL objective when u <= v.

    lam21 is pinned to 0 (``LAM21``). The other three parameters are updated
    cyclically: each step fixes one of lam20 / lam11 / lam10 and solves for
    the remaining two. Fixing lam20 uses a closed-form stationary point;
    fixing lam11 or lam10 uses a 1-D bisection (``convex_min_1d``).
    The first variable fixed is lam20 if ``lam20`` is truthy, else lam11 if
    ``lam11`` is truthy, else lam10 (which must then be provided).

    Iteration stops when the objective improved by less than
    OBJECTIVE_EPSILON compared with three steps (one full cycle) earlier, or
    at iteration 100.

    Returns:
        tuple: (lam10, lam20, lam11, 0.0, 0.5 * objective - d)

    Raises:
        AssertionError: if any parameter becomes negative.
    """
    # some intialization to start the alternating optimization
    LAM21 = 0.0
    i = 0
    objective_value_list = []

    if lam20:
        ordering = [0, 1, 2]
    elif lam11:
        ordering = [1, 0, 2]
    else:
        ordering = [1, 2, 0]

    while True:
        if i % 3 == ordering[0]:  # fix lam20
            # D: power left for p*lam11 + (1-p)*lam10 once lam20 is fixed.
            D = P - (1 - p) * (d - 1) * lam20
            C = D + p * v + (1 - p) * u

            E = math.sqrt((C + (1 - p) * g) / (C + p * g))
            tau = max((D / p + v - E * u) / (E + (1 - p) / p), 0.0)
            # Accept the stationary point when it respects lam20 <= lam10 and
            # lam11 >= 0; otherwise compare the two boundary solutions.
            if lam20 <= tau and tau <= P / (1 - p) - (d - 1) * lam20:
                lam10 = tau
                lam11 = max(D / p - (1 - p) * tau / p, 0.0)
            else:
                lam10_case1, lam11_case1 = lam20, max(P / p - (1 - p) * d * lam20 / p, 0.0)
                lam10_case2, lam11_case2 = max(P / (1 - p) - (d - 1) * lam20, 0), 0.0
                objective1 = symKL_objective(lam10=lam10_case1, lam20=lam20, lam11=lam11_case1, lam21=LAM21,
                                             u=u, v=v, d=d, g=g)
                objective2 = symKL_objective(lam10=lam10_case2, lam20=lam20, lam11=lam11_case2, lam21=LAM21,
                                             u=u, v=v, d=d, g=g)
                if objective1 < objective2:
                    lam10, lam11 = lam10_case1, lam11_case1
                else:
                    lam10, lam11 = lam10_case2, lam11_case2

        elif i % 3 == ordering[1]:  # fix lam11
            # D = lam10 + (d-1)*lam20 once lam11 is fixed.
            D = max((P - p * lam11) / (1 - p), 0.0)
            f = lambda x: symKL_objective(lam10=D - (d - 1) * x, lam20=x, lam11=lam11, lam21=LAM21,
                                          u=u, v=v, d=d, g=g)

            # Derivative of f w.r.t. x (= lam20). Squared denominators are
            # written as products of ratios because the squared form was
            # found numerically unstable.
            def f_prime(x):
                if x == 0.0 and u == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / v - (d - 1) / (lam11 + v) - (d - 1) / (x + u) * (v / (x + u)) + (lam11 + v + g) / (
                                D - (d - 1) * x + u) * ((d - 1) / (D - (d - 1) * x + u))

            # Upper bound D/d keeps lam20 <= lam10 = D - (d-1)*lam20.
            lam20 = convex_min_1d(xl=0.0, xr=D / d, f=f, f_prime=f_prime)
            lam10 = max(D - (d - 1) * lam20, 0.0)

        else:  # fix lam10
            D = max(P - (1 - p) * lam10, 0.0)  # avoid negative due to numerical error
            f = lambda x: symKL_objective(lam10=lam10, lam20=x, lam11=D / p - (1 - p) * (d - 1) * x / p, lam21=LAM21,
                                          u=u, v=v, d=d, g=g)

            # Derivative of f w.r.t. x (= lam20), in the same numerically
            # stabler product-of-ratios form as above.
            def f_prime(x):
                if x == 0.0 and u == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / v - (1 - p) * (d - 1) / (lam10 + u) / p - (d - 1) / (x + u) * (v / (x + u)) + (
                                lam10 + u + g) / (D / p - (1 - p) * (d - 1) * x / p + v) * (1 - p) * (d - 1) / p / (
                                       D / p - (1 - p) * (d - 1) * x / p + v)

            # Upper bound keeps lam11 >= 0 and lam20 <= lam10.
            lam20 = convex_min_1d(xl=0.0, xr=min(D / ((1 - p) * (d - 1)), lam10), f=f, f_prime=f_prime)
            lam11 = max(D / p - (1 - p) * (d - 1) * lam20 / p, 0.0)

        if lam10 < 0 or lam20 < 0 or lam11 < 0 or LAM21 < 0:  # check to make sure no negative values
            assert False, i

        objective_value_list.append(symKL_objective(lam10=lam10, lam20=lam20, lam11=lam11, lam21=LAM21,
                                                    u=u, v=v, d=d, g=g))
        # Power actually used: p*lam11 + p*(d-1)*LAM21 + (1-p)*lam10 + (1-p)*(d-1)*lam20

        # Converged: improvement over the last full cycle is below tolerance.
        if (i >= 3 and objective_value_list[-4] - objective_value_list[-1] < OBJECTIVE_EPSILON) or i >= 100:
            return lam10, lam20, lam11, LAM21, 0.5 * objective_value_list[-1] - d

        i += 1


def solve_small_pos(u, v, d, g, p, P, lam10=None, lam11=None, lam21=None):
    """Alternating minimization of the symmetric-KL objective when u > v.

    lam20 is pinned to 0 (``LAM20``) and does not change during the
    optimization. The other three parameters are updated cyclically: each
    step fixes one of lam21 / lam11 / lam10 and solves for the remaining two.
    Fixing lam21 uses a closed-form stationary point; fixing lam11 or lam10
    uses a 1-D bisection (``convex_min_1d``).
    The first variable fixed is lam21 if ``lam21`` is truthy, else lam11 if
    ``lam11`` is truthy, else lam10 (which must then be provided).

    Iteration stops when the objective improved by less than
    OBJECTIVE_EPSILON compared with three steps (one full cycle) earlier, or
    at iteration 100.

    Returns:
        tuple: (lam10, 0.0, lam11, lam21, 0.5 * objective - d)

    Raises:
        AssertionError: if any parameter becomes negative.
    """
    # some intialization to start the alternating optimization
    LAM20 = 0.0
    i = 0
    objective_value_list = []
    if lam21:
        ordering = [0, 1, 2]
    elif lam11:
        ordering = [1, 0, 2]
    else:
        ordering = [1, 2, 0]
    while True:
        if i % 3 == ordering[0]:  # fix lam21
            # D: power left for p*lam11 + (1-p)*lam10 once lam21 is fixed.
            D = P - p * (d - 1) * lam21
            C = D + p * v + (1 - p) * u

            E = math.sqrt((C + (1 - p) * g) / (C + p * g))
            tau = max((D / p + v - E * u) / (E + (1 - p) / p), 0.0)
            # Accept the stationary point when it respects lam10 >= 0 and
            # lam21 <= lam11; otherwise compare the two boundary solutions.
            if 0.0 <= tau and tau <= (P - p * d * lam21) / (1 - p):
                lam10 = tau
                lam11 = max(D / (p) - (1 - p) * tau / (p), 0.0)
            else:
                lam10_case1, lam11_case1 = 0, max(P / p - (d - 1) * lam21, 0.0)
                lam10_case2, lam11_case2 = max((P - p * d * lam21) / (1 - p), 0.0), lam21
                objective1 = symKL_objective(lam10=lam10_case1, lam20=LAM20, lam11=lam11_case1, lam21=lam21,
                                             u=u, v=v, d=d, g=g)
                objective2 = symKL_objective(lam10=lam10_case2, lam20=LAM20, lam11=lam11_case2, lam21=lam21,
                                             u=u, v=v, d=d, g=g)
                if objective1 < objective2:
                    lam10, lam11 = lam10_case1, lam11_case1
                else:
                    lam10, lam11 = lam10_case2, lam11_case2

        elif i % 3 == ordering[1]:  # fix lam11
            # D = (1-p)*lam10 + p*(d-1)*lam21 once lam11 is fixed.
            D = max(P - p * lam11, 0.0)
            f = lambda x: symKL_objective(lam10=(D - p * (d - 1) * x) / (1 - p), lam20=LAM20, lam11=lam11, lam21=x,
                                          u=u, v=v, d=d, g=g)

            # Derivative of f w.r.t. x (= lam21). Squared denominators are
            # written as products of ratios because the squared form was
            # found numerically unstable.
            def f_prime(x):
                if x == 0.0 and v == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / u - p * (d - 1) / (lam11 + v) / (1 - p) - (d - 1) / (x + v) * (u / (x + v)) + (
                                lam11 + v + g) / ((D - p * (d - 1) * x) / (1 - p) + u) * p * (d - 1) / (1 - p) / (
                                       (D - p * (d - 1) * x) / (1 - p) + u)

            # Upper bound keeps lam10 >= 0 and lam21 <= lam11.
            lam21 = convex_min_1d(xl=0.0, xr=min(D / p / (d - 1), lam11), f=f, f_prime=f_prime)
            lam10 = max((D - p * (d - 1) * lam21) / (1 - p), 0.0)

        else:  # fix lam10
            # D = lam11 + (d-1)*lam21 once lam10 is fixed.
            D = max((P - (1 - p) * lam10) / p, 0.0)
            f = lambda x: symKL_objective(lam10=lam10, lam20=LAM20, lam11=D - (d - 1) * x, lam21=x,
                                          u=u, v=v, d=d, g=g)

            # Derivative of f w.r.t. x (= lam21), in the same numerically
            # stabler product-of-ratios form as above.
            def f_prime(x):
                if x == 0.0 and v == 0.0:
                    return float('-inf')
                else:
                    return (d - 1) / u - (d - 1) / (lam10 + u) - (d - 1) / (x + v) * (u / (x + v)) + (lam10 + u + g) / (
                                D - (d - 1) * x + v) * (d - 1) / (D - (d - 1) * x + v)

            # Upper bound D/d keeps lam21 <= lam11 = D - (d-1)*lam21.
            lam21 = convex_min_1d(xl=0.0, xr=D / d, f=f, f_prime=f_prime)
            lam11 = max(D - (d - 1) * lam21, 0.0)

        if lam10 < 0 or LAM20 < 0 or lam11 < 0 or lam21 < 0:
            assert False, i

        objective_value_list.append(symKL_objective(lam10=lam10, lam20=LAM20, lam11=lam11, lam21=lam21,
                                                    u=u, v=v, d=d, g=g))
        # Power actually used: p*lam11 + p*(d-1)*lam21 + (1-p)*lam10 + (1-p)*(d-1)*LAM20

        # Converged: improvement over the last full cycle is below tolerance.
        if (i >= 3 and objective_value_list[-4] - objective_value_list[-1] < OBJECTIVE_EPSILON) or i >= 100:
            return lam10, LAM20, lam11, lam21, 0.5 * objective_value_list[-1] - d

        i += 1


def convex_min_1d(xl, xr, f, f_prime, eps=None):
    """Minimize a 1-D convex function on ``[xl, xr]`` by bisecting on its derivative.

    Args:
        xl (float): left end of the search interval.
        xr (float): right end; must satisfy ``xl <= xr <= 1e5``.
        f (callable): objective; only evaluated on the final bracket.
        f_prime (callable): derivative of ``f`` (may return +/-inf).
        eps (float | None): bracket width at which bisection stops. ``None``
            (the default) uses the module constant ``CONVEX_EPSILON``, which
            was the only option before this parameter existed.

    Returns:
        float: the right end of the current bracket if ``f_prime <= 0`` at both
        of its ends, the left end if ``f_prime >= 0`` at both ends, otherwise
        (once the bracket is at most ``eps`` wide) whichever of its left end,
        midpoint and right end has the smallest ``f`` (ties go to the smaller x).

    Raises:
        AssertionError: if ``xr > 1e5`` or ``xl > xr``.
    """
    tol = CONVEX_EPSILON if eps is None else eps
    lo, hi = xl, xr
    # Loop instead of recursion; the checks run once per halving of the
    # bracket, in the same order (and with the same asserts) each time.
    while True:
        assert hi <= 1e5
        assert lo <= hi, (lo, hi)
        mid = (lo + hi) / 2
        if abs(lo - hi) <= tol:
            return min((f(x), x) for x in [lo, mid, hi])[1]
        if f_prime(lo) <= 0 and f_prime(hi) <= 0:
            return hi  # f non-increasing on the bracket
        elif f_prime(lo) >= 0 and f_prime(hi) >= 0:
            return lo  # f non-decreasing on the bracket
        # Derivative changes sign inside: keep the half containing the root.
        if f_prime(mid) > 0:
            hi = mid
        else:
            lo = mid


def small_neg_problem_string(u, v, d, g, p, P):
    """Render the u <= v problem as text (x = lam10, y = lam11, z = lam20, lam21 = 0)."""
    template = ('minimize ({d}-1)*(z + {u})/{v} + ({d}-1)*{v}/(z+{u})'
                '+(x+{u}+{g})/(y+{v}) + (y+{v}+{g})/(x+{u})'
                ' subject to x>=0, y>=0, z>=0, z<=x, '
                '{p}*y+(1-{p})*x+(1-{p})*({d}-1)*z={P}')
    return template.format(u=u, v=v, d=d, g=g, p=p, P=P)


def small_pos_problem_string(u, v, d, g, p, P):
    """Render the u > v problem as text (x = lam10, y = lam11, z = lam21, lam20 = 0)."""
    template = ('minimize ({d}-1)*{u}/(z+{v}) + ({d}-1)*(z + {v})/{u}'
                ' + (x+{u}+{g})/(y+{v}) + (y+{v}+{g})/(x+{u})'
                ' subject to x>=0, y>=0, z>=0, z<=y, '
                '{p}*y+(1-{p})*x+{p}*({d}-1)*z={P}')
    return template.format(u=u, v=v, d=d, g=g, p=p, P=P)


def zero_uv_problem_string(g, p, P):
    """Render the u = v = 0 problem as text (x = lam10, y = lam11)."""
    template = ('minimize (x+{g})/y + (y+{g})/x'
                ' subject to x>=0, y>=0, {p}*y+(1-{p})*x={P}')
    return template.format(g=g, p=p, P=P)

def KL_gradient_perturb_function_creator(Y_Train, g, p_frac='pos_frac', dynamic=False, error_prob_lower_bound=None,
                                         sumKL_threshold=None, init_scale=1.0, uv_choice='uv'):
    """Add class-dependent Gaussian noise to per-example gradient values.

    The noise parameters come from ``solve_isotropic_covariance`` under the
    power budget ``P = scale * ||mean(g_pos) - mean(g_neg)||^2``. Positive
    examples get noise with std ``sqrt(lam11 - lam21)`` along the class-mean
    difference (plus ``sqrt(lam21)`` isotropic if lam21 > 0); negative
    examples get ``sqrt(lam10 - lam20)`` (plus ``sqrt(lam20)``).

    Args:
        Y_Train (pandas.DataFrame): labels in the first column; ``== 1`` marks
            a positive example, anything else a negative. Labels are also
            used directly as 0/1 noise masks, so they are assumed to be 0/1.
        g (array-like): one gradient value per row of ``Y_Train`` (1-D).
            It is NOT modified; a perturbed copy is returned.
        p_frac: ``'pos_frac'`` to use the batch's fraction of positives as
            ``p``, otherwise a number converted with ``float``.
        dynamic (bool): if True, enlarge the power budget (``scale *= 1.5``)
            until the solver's sumKL is ``<= sumKL_threshold``.
        error_prob_lower_bound (float | None): with ``dynamic``, overrides
            ``sumKL_threshold`` by ``(2 - 4 * error_prob_lower_bound) ** 2``.
        sumKL_threshold (float | None): target for the dynamic search.
        init_scale (float): initial power-budget scale.
        uv_choice (str): ``'uv'`` per-class variances, ``'same'`` their
            average for both classes, ``'zero'`` for ``u = v = 0``.

    Returns:
        numpy.ndarray: perturbed copy of ``g`` with a floating dtype.

    Raises:
        ValueError: unknown ``uv_choice``; ``dynamic`` without any threshold;
            a batch lacking positives or negatives (class means are undefined
            and the solver divides by ``p`` and ``1 - p``).
    """
    if dynamic and (error_prob_lower_bound is not None):
        # An explicit error-probability target takes precedence over any
        # sumKL_threshold passed in.
        sumKL_threshold = (2 - 4 * error_prob_lower_bound) ** 2
    if dynamic and sumKL_threshold is None:
        raise ValueError('dynamic=True requires sumKL_threshold or error_prob_lower_bound')
    if uv_choice not in ('uv', 'same', 'zero'):
        raise ValueError("uv_choice must be 'uv', 'same' or 'zero', got {!r}".format(uv_choice))

    y = list(Y_Train.iloc[:, 0])
    pos = [i for i, label in enumerate(y) if label == 1]
    neg = [i for i, label in enumerate(y) if label != 1]
    if not pos or not neg:
        raise ValueError('Marvell perturbation needs both positive and negative examples in the batch '
                         '(got {} positive, {} negative)'.format(len(pos), len(neg)))

    pos_g = [g[i] for i in pos]
    neg_g = [g[i] for i in neg]
    pos_g_mean = numpy.mean(pos_g)
    neg_g_mean = numpy.mean(neg_g)
    avg_pos_coordinate_var = numpy.mean(numpy.var(pos_g))
    avg_neg_coordinate_var = numpy.mean(numpy.var(neg_g))

    g_diff = pos_g_mean - neg_g_mean
    g_diff_norm = numpy.sqrt(g_diff ** 2)

    # Work on a floating copy: ``g`` may share memory with a torch tensor
    # (``Tensor.numpy()`` does not copy) and the noise is added in place below.
    perturbed_g = numpy.array(g, copy=True)
    if not numpy.issubdtype(perturbed_g.dtype, numpy.floating):
        perturbed_g = perturbed_g.astype(float)

    if g_diff_norm == 0.0:
        # Identical class means: the budget P = scale * 0 is zero, so every
        # (non-negative) lambda is 0 and there is no noise to add. Returning
        # here also avoids the 0/0 -> NaN in the directional terms below.
        return perturbed_g

    if uv_choice == 'uv':
        u = float(avg_neg_coordinate_var)
        v = float(avg_pos_coordinate_var)
    elif uv_choice == 'same':
        u = v = float(avg_neg_coordinate_var + avg_pos_coordinate_var) / 2.0
    else:  # 'zero'
        u, v = 0.0, 0.0

    # NOTE(review): d is the number of examples in the batch, while
    # solve_isotropic_covariance documents d as the dimension of the
    # protected activation (g has one value per example here) -- confirm.
    d = len(Y_Train)
    if p_frac == 'pos_frac':
        p = float(numpy.sum(y) / len(y))  # p is set as the fraction of positive in the batch
    else:
        p = float(p_frac)

    scale = init_scale
    lam10, lam20, lam11, lam21 = None, None, None, None
    while True:
        P = scale * g_diff_norm ** 2
        # Warm-start each solve from the previous iteration's solution.
        lam10, lam20, lam11, lam21, sumKL = solve_isotropic_covariance(
            u=u, v=v, d=d, g=g_diff_norm ** 2, p=p, P=P,
            lam10_init=lam10, lam20_init=lam20, lam11_init=lam11, lam21_init=lam21)
        if not dynamic or sumKL <= sumKL_threshold:
            break
        # Target not met yet: enlarge the power budget and re-solve. Without
        # this update the dynamic loop could never terminate.
        scale *= 1.5

    n = len(y)
    neg_mask = [1 - label for label in y]

    # Positive examples: noise along g_diff. max(..., 0.0) guards against a
    # round-off negative, which would otherwise produce NaN.
    perturbed_g += numpy.multiply(numpy.random.normal(0, 1, n), y) * g_diff * (
        numpy.sqrt(max(lam11 - lam21, 0.0)) / g_diff_norm)
    if lam21 > 0.0:
        # Positive examples: additional isotropic noise.
        perturbed_g += numpy.random.normal(0, 1, n) * y * numpy.sqrt(lam21)

    # Negative examples: noise along g_diff.
    perturbed_g += numpy.multiply(numpy.random.normal(0, 1, n), neg_mask) * g_diff * (
        numpy.sqrt(max(lam10 - lam20, 0.0)) / g_diff_norm)
    if lam20 > 0.0:
        # Negative examples: additional isotropic noise.
        perturbed_g += numpy.random.normal(0, 1, n) * neg_mask * numpy.sqrt(lam20)
    return perturbed_g


# SplitNN
import torch


class Client_marvell(torch.nn.Module):
    """Client side of a two-party SplitNN.

    Runs ``client_model`` and hands the server a detached leaf copy of its
    output, then back-propagates whatever gradient the server returns.

    Args:
        client_model (torch.nn.Module): client-side model.

    Attributes:
        client_model (torch.nn.Module): client-side model.
        client_side_intermidiate (torch.Tensor): output of ``client_model``
            from the latest ``forward``, still attached to the client graph.
        grad_from_server (torch.Tensor): gradient received in the latest
            ``client_backward``.
    """

    def __init__(self, client_model):
        super().__init__()
        self.client_model = client_model
        self.client_side_intermidiate = None
        self.grad_from_server = None

    def forward(self, inputs):
        """Run the client model and return the tensor sent to the server.

        Args:
            inputs (torch.Tensor): the input data.

        Returns:
            torch.Tensor: detached copy of the client output with
            ``requires_grad`` set, so the server's backward pass stops here
            and stores the cut-layer gradient in its ``.grad``.
        """
        self.client_side_intermidiate = self.client_model(inputs)
        intermidiate_to_server = self.client_side_intermidiate.detach().requires_grad_()
        return intermidiate_to_server

    def client_backward(self, grad_from_server):
        """Back-propagate ``grad_from_server`` through the client model.

        Raises:
            RuntimeError: if called before ``forward``.
        """
        if self.client_side_intermidiate is None:
            raise RuntimeError('client_backward() called before forward()')
        self.grad_from_server = grad_from_server
        self.client_side_intermidiate.backward(grad_from_server)

    def train(self, mode=True):
        """Put the client model in train (``mode=True``) or eval mode.

        Matches ``nn.Module.train``: accepts ``mode``, updates
        ``self.training`` and returns ``self``.
        """
        self.training = mode
        if mode:
            self.client_model.train()
        else:
            self.client_model.eval()
        return self

    def eval(self):
        """Put the client model in eval mode; returns ``self``."""
        return self.train(False)


class Server_marvell(torch.nn.Module):
    """Server side of a two-party SplitNN.

    Runs ``server_model`` on the tensor received from the client and, after
    the server loss has been back-propagated, returns the gradient that
    reached that tensor.

    Args:
        server_model (torch.nn.Module): server-side model.

    Attributes:
        server_model (torch.nn.Module): server-side model.
        intermidiate_to_server (torch.Tensor): input of the latest ``forward``.
        grad_to_client (torch.Tensor): gradient returned by the latest
            ``server_backward``.
    """

    def __init__(self, server_model):
        super().__init__()
        self.server_model = server_model
        self.intermidiate_to_server = None
        self.grad_to_client = None

    def forward(self, intermidiate_to_server):
        """Run the server model on the client's output.

        Args:
            intermidiate_to_server (torch.Tensor): the output of the
                client-side model.

        Returns:
            torch.Tensor: outputs of the server-side model.
        """
        self.intermidiate_to_server = intermidiate_to_server
        outputs = self.server_model(intermidiate_to_server)
        return outputs

    def server_backward(self):
        """Return (and store) a copy of the gradient on the client's tensor.

        Raises:
            RuntimeError: if ``forward`` has not run or no backward pass has
                populated the gradient yet.
        """
        received = self.intermidiate_to_server
        grad = None if received is None else received.grad
        if grad is None:
            raise RuntimeError('server_backward() needs forward() and a backward pass of the loss first')
        self.grad_to_client = grad.clone()
        return self.grad_to_client

    def train(self, mode=True):
        """Put the server model in train (``mode=True``) or eval mode.

        Matches ``nn.Module.train``: accepts ``mode``, updates
        ``self.training`` and returns ``self``.
        """
        self.training = mode
        if mode:
            self.server_model.train()
        else:
            self.server_model.eval()
        return self

    def eval(self):
        """Put the server model in eval mode; returns ``self``."""
        return self.train(False)


class SplitNN_marvell(torch.nn.Module):
    """Two-party SplitNN whose cut-layer gradient is perturbed before reaching the client.

    Args:
        client (torch.nn.Module): client part (e.g. ``Client_marvell``).
        server (torch.nn.Module): server part (e.g. ``Server_marvell``).
        client_optimizer: optimizer over the client parameters.
        server_optimizer: optimizer over the server parameters.

    Attributes:
        grad_to_client_1 (torch.Tensor): unperturbed gradient from the server.
        grad_to_client (torch.Tensor): perturbed gradient sent to the client.
        intermidiate_to_server (torch.Tensor): client output of the latest
            ``forward``.
        labels (pandas.DataFrame): labels of the latest ``forward``, used by
            the perturbation in ``backward``.
    """

    def __init__(self, client, server,
                 client_optimizer, server_optimizer
                 ):
        super().__init__()
        self.client = client
        self.server = server
        self.client_optimizer = client_optimizer
        self.server_optimizer = server_optimizer
        self.grad_to_client = None
        self.grad_to_client_1 = None
        self.intermidiate_to_server = None
        self.labels = None

    def forward(self, inputs, labels):
        """Run client then server; remember ``labels`` for ``backward``.

        Returns:
            tuple: (server outputs, tensor the client sent to the server).
        """
        self.intermidiate_to_server = self.client(inputs)
        outputs = self.server(self.intermidiate_to_server)
        self.labels = pd.DataFrame(labels.detach().cpu().numpy())
        return outputs, self.intermidiate_to_server

    def backward(self):
        """Perturb the server's cut-layer gradient and back-propagate it on the client.

        Must be called after the server loss has been back-propagated.
        Assumes the cut layer has a single output column: only column 0 is
        perturbed and the result is reshaped to ``(-1, 1)`` -- TODO confirm
        before using wider cut layers.
        """
        self.grad_to_client_1 = self.server.server_backward()

        # .copy(): Tensor.numpy() shares memory with the tensor, so without it
        # an in-place perturbation would also overwrite grad_to_client_1.
        raw_grad = self.grad_to_client_1.detach().cpu().numpy().T[0].copy()
        perturbed = KL_gradient_perturb_function_creator(self.labels,
                                                         raw_grad,
                                                         dynamic=False, error_prob_lower_bound=None,
                                                         sumKL_threshold=None, init_scale=1.0, uv_choice='uv')
        # Keep the server gradient's dtype/device (torch.Tensor(...) would
        # silently truncate a float64 gradient to float32).
        self.grad_to_client = torch.tensor(perturbed,
                                           dtype=self.grad_to_client_1.dtype,
                                           device=self.grad_to_client_1.device).reshape(-1, 1)
        self.client.client_backward(self.grad_to_client)

    def zero_grads(self):
        """Clear the gradients of both optimizers."""
        self.client_optimizer.zero_grad()
        self.server_optimizer.zero_grad()

    def step(self):
        """Take one optimizer step on both parties."""
        self.client_optimizer.step()
        self.server_optimizer.step()

    def train(self, mode=True):
        """Switch both parties to train (``mode=True``) or eval mode; returns ``self``.

        Uses the parties' own no-argument ``train()`` / ``eval()`` methods so
        it works with any client/server exposing those.
        """
        self.training = mode
        if mode:
            self.client.train()
            self.server.train()
        else:
            self.client.eval()
            self.server.eval()
        return self

    def eval(self):
        """Switch both parties to eval mode; returns ``self``."""
        return self.train(False)


def _two_sided_auc(labels, scores):
    """Return ``max(AUC, 1 - AUC)`` of ``scores`` against ``labels`` (via torch_auc).

    A score that ranks positives consistently below negatives leaks as much
    label information as one ranking them above, so both orientations count.
    """
    auc = torch_auc(labels, scores)
    return max(auc, 1 - auc)


def _gradient_leak_scores(grad_from_server, labels):
    """Compute the three per-example label-leakage scores for one batch.

    Args:
        grad_from_server (torch.Tensor): gradient the client received; assumed
            to be (batch, 1) like the cut layer in SplitNN_marvell -- TODO
            confirm for wider cut layers.
        labels (torch.Tensor): 0/1 labels with the same shape.

    Returns:
        tuple: (g_norm, g_mean, g_inner) where
        g_norm is the per-row L2 norm of the gradient;
        g_mean is an int64 (batch, 1) tensor, 1 where the gradient is closer
        to the positive-class mean than to the negative-class mean;
        g_inner is an int64 (batch,) tensor, 1 where the gradient exceeds
        the batch median.
    """
    grad_np = grad_from_server.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()
    g_norm = grad_from_server.pow(2).sum(dim=1).sqrt()

    # Class means weighted by label. Counting non-zero grad*label products
    # instead would misclassify positives whose gradient is exactly 0.
    pos_weight = labels_np
    neg_weight = 1 - labels_np
    mean_1 = np.multiply(grad_np, pos_weight).sum() / pos_weight.sum()
    mean_0 = np.multiply(grad_np, neg_weight).sum() / neg_weight.sum()

    closer_to_pos = np.square(grad_np - mean_1) < np.square(grad_np - mean_0)
    g_mean = torch.tensor(closer_to_pos.astype(np.int64).reshape(-1, 1))

    # Median computed once per batch, not once per example.
    median = grad_from_server.median().item()
    g_inner = torch.tensor((grad_np > median).astype(np.int64).reshape(-1))
    return g_norm, g_mean, g_inner


def train_marvell(Epochs,features,train_loader,test_loader,lr=1e-5, info=True):
    """Train a SplitNN with Marvell-perturbed cut-layer gradients and report leakage.

    After every training step one batch is taken from a fresh iterator over
    ``test_loader`` and scored. Per epoch, the train/test AUC, the total
    variation reported by ``totalvaraition`` and three label-leakage AUCs
    (gradient norm, nearest class mean, above-median) are computed.

    Args:
        Epochs (int): number of epochs (at least 1).
        features: array-like whose last dimension is the client input size.
        train_loader: iterable of ``(inputs, labels)`` batches.
        test_loader: loader whose first batch is ``(inputs, labels, ...)``.
        lr (float): Adam learning rate for both parties.
        info (bool): print metrics every 10th epoch and on the last one.

    Returns:
        tuple: (train_auc, test_auc, train_tvd, na_leak_auc, ma_leak_auc,
        cos_leak_auc, splitnn_marvell) for the last epoch.

    Raises:
        ValueError: if ``Epochs < 1`` (there would be no metrics to return).
    """
    if Epochs < 1:
        raise ValueError('Epochs must be at least 1, got {}'.format(Epochs))

    input_dim = features.shape[-1]
    model_1 = FirstNet(input_dim)
    model_1 = model_1.to(device)
    model_2 = SecondNet()
    model_2 = model_2.to(device)
    model_1.double()
    model_2.double()

    opt_1 = optim.Adam(model_1.parameters(), lr=lr)
    opt_2 = optim.Adam(model_2.parameters(), lr=lr)
    BCE = nn.BCELoss()

    client = Client_marvell(model_1)
    server = Server_marvell(model_2)
    splitnn_marvell = SplitNN_marvell(client, server, opt_1, opt_2)
    splitnn_marvell.train()

    n_train = len(train_loader.dataset)
    for epoch in range(Epochs):
        epoch_loss = 0
        epoch_outputs, epoch_labels = [], []
        epoch_outputs_test, epoch_labels_test = [], []
        epoch_g_norm, epoch_g_mean, epoch_g_inner, epoch_g = [], [], [], []

        for inputs, labels in train_loader:
            splitnn_marvell.zero_grads()
            inputs = inputs.to(device).double()
            labels = labels.to(device).double()

            outputs, _ = splitnn_marvell(inputs, labels)
            loss = BCE(outputs, labels)
            loss.backward()
            splitnn_marvell.backward()
            splitnn_marvell.step()

            # BCELoss averages over the batch; weight by batch size so that
            # epoch_loss is the per-example mean over the whole dataset.
            epoch_loss += loss.item() * labels.shape[0] / n_train

            # Detach so the stored outputs do not keep the autograd graph alive.
            epoch_outputs.append(outputs.detach())
            epoch_labels.append(labels)

            grad_from_server = splitnn_marvell.client.grad_from_server
            g_norm, g_mean, g_inner = _gradient_leak_scores(grad_from_server, labels)
            epoch_g_norm.append(g_norm)
            epoch_g_mean.append(g_mean)
            epoch_g_inner.append(g_inner)
            epoch_g.append(grad_from_server)

            # Score one test batch per training step. Inputs get the same
            # device/dtype as training inputs; no_grad because these outputs
            # are only scored, never back-propagated.
            t = next(iter(test_loader))
            test_inputs, labels_test = t[0], t[1]
            with torch.no_grad():
                outputs_test, _ = splitnn_marvell(test_inputs.to(device).double(), labels_test)
            epoch_outputs_test.append(outputs_test)
            epoch_labels_test.append(labels_test)

        all_labels = torch.cat(epoch_labels)
        train_auc = torch_auc(all_labels, torch.cat(epoch_outputs))
        test_auc = torch_auc(torch.cat(epoch_labels_test),
                             torch.cat(epoch_outputs_test))
        train_tvd = totalvaraition(all_labels, torch.cat(epoch_g))
        na_leak_auc = _two_sided_auc(all_labels, torch.cat(epoch_g_norm).view(-1, 1))
        ma_leak_auc = _two_sided_auc(all_labels, torch.cat(epoch_g_mean).view(-1, 1))
        cos_leak_auc = _two_sided_auc(all_labels, torch.cat(epoch_g_inner).view(-1, 1))

        if info and (epoch % 10 == 0 or epoch == Epochs - 1):
            print('Epoch', epoch, 'Training Loss', epoch_loss,
                  'Training AUC', train_auc,
                  'Testing AUC', test_auc,
                  "TVD", train_tvd,
                  'NA Leak AUC', na_leak_auc,
                  'MA Leak AUC', ma_leak_auc,
                  'Median Leak AUC', cos_leak_auc
                  )
    return train_auc, test_auc, train_tvd, na_leak_auc, ma_leak_auc, cos_leak_auc, splitnn_marvell

# SplitNN with multiple feature-partitioned clients (Marvell gradient perturbation)

class SplitNN_marvell_multiple(torch.nn.Module):
    """Split neural network with several feature-partitioned clients and one server.

    Client ``i`` receives the input columns ``features[i]:features[i + 1]``.
    The clients' outputs are averaged and the average is fed to the server.
    On the backward pass the gradient returned by ``server.server_backward()``
    is perturbed with ``KL_gradient_perturb_function_creator`` (Marvell)
    before it is sent, divided by the number of clients, to every client.

    Args:
        clients: list of client wrappers. Each must be callable on its input
            slice and expose ``client_backward(grad)``, ``train()`` and
            ``eval()``.
        server: server wrapper. It must be callable on the aggregated
            activation and expose ``server_backward()``, ``train()`` and
            ``eval()``.
        clients_optimizers: one optimizer per client, in the same order.
        server_optimizer: optimizer of the server.
        features: column boundaries of the client slices. It needs at least
            ``len(clients) + 1`` entries.

    Attributes:
        clients, server, client_optimizers, server_optimizer, features:
            the constructor arguments.
        number (int): number of clients.
        labels (pandas.DataFrame): labels passed to the last ``forward`` call.
        intermidiate_to_server (torch.Tensor): mean client activation of the
            last ``forward`` call (``retain_grad()`` enabled).
        grad_to_client (torch.Tensor | None): perturbed gradient computed by
            the last ``backward`` call, reshaped to ``(-1, 1)``.
    """

    def __init__(self, clients, server,
                 clients_optimizers, server_optimizer, features
                 ):
        super().__init__()
        self.clients = clients
        self.number = len(clients)
        self.server = server
        self.client_optimizers = clients_optimizers
        self.server_optimizer = server_optimizer
        self.grad_to_client = None
        self.intermidiate_to_server = 0
        self.features = features

    def forward(self, inputs, labels):
        """Run every client on its feature slice, then run the server.

        Args:
            inputs: 2-D tensor whose columns are split among the clients
                according to ``self.features``.
            labels: labels of the batch. They are stored as a DataFrame and
                used by the Marvell perturbation in ``backward``.

        Returns:
            tuple: ``(outputs, aggregated_activation, per_client_activations)``.
            The aggregated activation is the mean of the client outputs, and
            ``per_client_activations`` holds each ``client(x_i) / number`` term.
        """
        per_client = []
        self.intermidiate_to_server = 0
        for i, client in enumerate(self.clients):
            input_data = inputs[:, self.features[i]:self.features[i + 1]]
            v = client(input_data) / self.number
            per_client.append(v)
            self.intermidiate_to_server += v
        # Keep the gradient of this non-leaf tensor after backward().
        # Presumably server.server_backward() reads it; confirm in Server_marvell.
        self.intermidiate_to_server.retain_grad()
        outputs = self.server(self.intermidiate_to_server)
        # .cpu() makes the NumPy conversion work for tensors on any device.
        self.labels = pd.DataFrame(labels.detach().cpu().numpy())
        return outputs, self.intermidiate_to_server, per_client

    def backward(self):
        """Back-propagate: server first, then Marvell perturbation, then clients.

        The first column (``.T[0]``) of the server gradient is perturbed
        together with the labels stored by ``forward``. The result is stored
        in ``self.grad_to_client`` with shape ``(-1, 1)``. Every client
        receives ``grad_to_client / number``.
        """
        server_grad = self.server.server_backward()
        perturbed = KL_gradient_perturb_function_creator(
            self.labels,
            server_grad.detach().cpu().numpy().T[0],
            dynamic=False, error_prob_lower_bound=None,
            sumKL_threshold=None, init_scale=1.0, uv_choice='uv')
        # Keep the server gradient's dtype (float64 for .double() models)
        # instead of the legacy torch.Tensor constructor's float32 down-cast.
        self.grad_to_client = torch.as_tensor(
            perturbed, dtype=server_grad.dtype).reshape(-1, 1)

        for client in self.clients:
            client.client_backward(self.grad_to_client / self.number)

    def zero_grads(self):
        """Clear the gradients of every client optimizer and of the server optimizer."""
        for optimizer in self.client_optimizers:
            optimizer.zero_grad()
        self.server_optimizer.zero_grad()

    def step(self):
        """Apply one optimizer step for every client and for the server."""
        for optimizer in self.client_optimizers:
            optimizer.step()
        self.server_optimizer.step()

    def train(self, mode=True):
        """Switch every client and the server to training mode.

        With ``mode=False`` they are switched to eval mode instead. Like
        ``torch.nn.Module.train``, this accepts ``mode``, updates
        ``self.training`` and returns ``self``.
        """
        self.training = mode
        for client in self.clients:
            if mode:
                client.train()
            else:
                client.eval()
        if mode:
            self.server.train()
        else:
            self.server.eval()
        return self

    def eval(self):
        """Switch every client and the server to eval mode; returns ``self``."""
        return self.train(False)


def _folded_leak_auc(labels, score_batches):
    """Return the leak AUC of per-example attack scores.

    Args:
        labels: concatenated labels of the epoch.
        score_batches: list of per-batch score tensors, concatenated and
            reshaped to a column.

    The AUC is folded to ``max(auc, 1 - auc)`` so that the direction of the
    score does not matter.
    """
    auc = torch_auc(labels, torch.cat(score_batches).view(-1, 1))
    return max(auc, 1 - auc)


def train_marvell_multiple(Epochs, features, train_loader, test_loader, lr=1e-5, info=False):
    """Train a three-client SplitNN protected by Marvell gradient perturbation.

    Setup:
        Three ``FirstNet`` clients (client ``k`` sees columns
        ``features[k]:features[k + 1]``) and one ``SecondNet`` server are
        trained in float64 with Adam and ``BCELoss``.

    Per training batch:
        After each batch, the first batch of a fresh ``test_loader``
        iterator is also scored.

    Per epoch, the following metrics are computed:
        * train/test AUC (``torch_auc``);
        * the TVD (``totalvaraition``) of the perturbed gradients;
        * the folded leak AUCs of the three scores returned by ``Attacks``
          (printed as NA / MA / Cos). These are computed both for the full
          perturbed gradient and for the ``grad / 3`` share each client
          receives.

    Args:
        Epochs (int): number of training epochs; must be at least 1.
        features (sequence of int): column boundaries. ``features[0:4]``
            define the three client slices.
        train_loader: iterable of ``(inputs, labels)`` batches. It must expose
            ``.dataset`` (used to normalise the epoch loss).
        test_loader: iterable whose batches start with ``inputs, labels``.
        lr (float): Adam learning rate for all parties.
        info (bool): if true, print the metrics every 10th epoch and at the
            last epoch.

    Returns:
        tuple: metrics of the final epoch, then the trained network, in the
        order ``(train_auc, test_auc, train_tvd, na_leak_auc, ma_leak_auc,
        cos_leak_auc, na_leak_auc1, ma_leak_auc1, cos_leak_auc1, na_leak_auc2,
        ma_leak_auc2, cos_leak_auc2, na_leak_auc3, ma_leak_auc3, cos_leak_auc3,
        splitnn)``.

    Raises:
        ValueError: if ``Epochs < 1`` (there would be no metrics to return).
    """
    if Epochs < 1:
        raise ValueError('Epochs must be at least 1')

    num_clients = 3
    # Build all client networks first, then their wrappers and optimizers.
    # This keeps any random parameter initialisation in the original order.
    client_models = [FirstNet(input_dim=features[k + 1] - features[k]).to(device)
                     for k in range(num_clients)]
    for model in client_models:
        model.double()
    clients = [Client_marvell(model) for model in client_models]
    client_optimizers = [optim.Adam(model.parameters(), lr=lr) for model in client_models]

    server_model = SecondNet().to(device)
    server_model.double()
    server_optimizer = optim.Adam(server_model.parameters(), lr=lr)
    server = Server_marvell(server_model)

    BCE = nn.BCELoss()
    splitnn = SplitNN_marvell_multiple(clients, server, client_optimizers,
                                       server_optimizer, features)
    splitnn.train()

    for epoch in range(Epochs):
        epoch_loss = 0
        epoch_outputs, epoch_labels = [], []
        epoch_outputs_test, epoch_labels_test = [], []
        epoch_g = []
        # Attack scores on the full perturbed gradient.
        epoch_g_norm, epoch_g_mean, epoch_g_inner = [], [], []
        # client_scores[k] = (norm, mean, inner) score lists for client k.
        client_scores = [([], [], []) for _ in range(num_clients)]

        for inputs, labels in train_loader:
            splitnn.zero_grads()
            inputs = inputs.to(device).double()
            labels = labels.to(device).double()

            outputs, _, _ = splitnn(inputs, labels)
            loss = BCE(outputs, labels)
            # retain_graph: splitnn.backward() afterwards back-propagates into
            # the clients (client_backward), presumably through this graph.
            loss.backward(retain_graph=True)
            splitnn.backward()
            splitnn.step()

            epoch_loss += loss.item() / len(train_loader.dataset)
            # Detach so the stored batch does not keep its autograd graph alive.
            epoch_outputs.append(outputs.detach())
            epoch_labels.append(labels)

            grad = splitnn.grad_to_client
            g_norm, g_mean, g_inner = Attacks(grad, labels)
            epoch_g_norm.append(g_norm)
            epoch_g_mean.append(g_mean)
            epoch_g_inner.append(g_inner)
            epoch_g.append(grad)

            # In splitnn.backward() every client receives the same share,
            # grad / number.
            for norms, means, inners in client_scores:
                c_norm, c_mean, c_inner = Attacks(grad / splitnn.number, labels)
                norms.append(c_norm)
                means.append(c_mean)
                inners.append(c_inner)

            # NOTE(review): this scores the first batch of a fresh iterator on
            # every training step, with the model still in train mode. It is
            # the same batch each time unless test_loader shuffles.
            test_batch = next(iter(test_loader))
            test_inputs = test_batch[0].to(device).double()
            test_labels = test_batch[1].to(device).double()
            outputs_test, _, _ = splitnn(test_inputs, test_labels)
            epoch_outputs_test.append(outputs_test.detach())
            epoch_labels_test.append(test_labels)

        labels_all = torch.cat(epoch_labels)
        train_auc = torch_auc(labels_all, torch.cat(epoch_outputs))
        test_auc = torch_auc(torch.cat(epoch_labels_test),
                             torch.cat(epoch_outputs_test))
        train_tvd = totalvaraition(labels_all, torch.cat(epoch_g))

        na_leak_auc = _folded_leak_auc(labels_all, epoch_g_norm)
        ma_leak_auc = _folded_leak_auc(labels_all, epoch_g_mean)
        cos_leak_auc = _folded_leak_auc(labels_all, epoch_g_inner)
        client_leaks = [
            (_folded_leak_auc(labels_all, norms),
             _folded_leak_auc(labels_all, means),
             _folded_leak_auc(labels_all, inners))
            for norms, means, inners in client_scores
        ]

        if info and (epoch % 10 == 0 or epoch == Epochs - 1):
            print('Epoch', epoch, 'Training Loss', epoch_loss,
                  'Training AUC', train_auc,
                  'Testing AUC', test_auc,
                  'TVD', train_tvd,
                  'NA Leak AUC', na_leak_auc,
                  'MA Leak AUC', ma_leak_auc,
                  'Cos Leak AUC', cos_leak_auc
                  )
            for k, (na, ma, cos) in enumerate(client_leaks, start=1):
                print(f'Client{k}',
                      'NA Leak AUC', na,
                      'MA Leak AUC', ma,
                      'Cos Leak AUC', cos
                      )

    # Flatten the per-client (na, ma, cos) triples of the final epoch.
    client_flat = tuple(auc for leaks in client_leaks for auc in leaks)
    return (train_auc, test_auc, train_tvd,
            na_leak_auc, ma_leak_auc, cos_leak_auc) + client_flat + (splitnn,)
